{
 "cells": [
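  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# TensorFlow 2 Tutorial: Keras Overview\n",
    "\n",
    "Keras is a high-level API for building and training deep learning models. It is suited to rapid prototyping, advanced research, and production.\n",
    "\n",
    "Keras has three main advantages:\n",
    "it is user-friendly, modular and composable, and easy to extend.\n",
    "\n",
    "\n",
    "##  1 Importing tf.keras\n",
    "TensorFlow 2 recommends building networks with tf.keras. Common neural network layers are available in tf.keras.layers. (The Keras version bundled as tf.keras may differ from the latest standalone keras release.)"
   ]
  },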
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2.0.0\n",
      "2.2.4-tf\n"
     ]
    }
   ],
   "source": [
    "import tensorflow as tf\n",
    "from tensorflow.keras import layers\n",
    "print(tf.__version__)\n",
    "print(tf.keras.__version__)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2 Building a Simple Model\n",
    "### 2.1 Stacking Layers\n",
    "The most common type of model is a stack of layers: the tf.keras.Sequential model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = tf.keras.Sequential()\n",
    "model.add(layers.Dense(32, activation='relu'))\n",
    "model.add(layers.Dense(32, activation='relu'))\n",
    "model.add(layers.Dense(10, activation='softmax'))"
   ]
  },
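  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2.2 Configuring Layers\n",
    "\n",
    "The main configuration parameters of the layers in tf.keras.layers are:\n",
    "\n",
    "activation: sets the activation function of the layer. It can be given as the name of a built-in function or as a callable object. By default, no activation is applied.\n",
    "\n",
    "kernel_initializer and bias_initializer: the initialization schemes for the layer's weights (kernel and bias). Each can be given as a name or as a callable object. For Dense, the kernel defaults to the \"Glorot uniform\" initializer and the bias defaults to zeros.\n",
    "\n",
    "kernel_regularizer and bias_regularizer: the regularization schemes applied to the layer's weights (kernel and bias), such as L1 or L2 regularization. By default, no regularization is applied."
   ]
  },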
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tensorflow.python.keras.layers.core.Dense at 0x7f0d70476518>"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Activation given by name or as a callable\n",
    "layers.Dense(32, activation='sigmoid')\n",
    "layers.Dense(32, activation=tf.sigmoid)\n",
    "# Kernel initializer given by name or as a callable\n",
    "layers.Dense(32, kernel_initializer='orthogonal')\n",
    "layers.Dense(32, kernel_initializer=tf.keras.initializers.glorot_normal)\n",
    "# L2 and L1 regularization on the kernel, each with factor 0.01\n",
    "layers.Dense(32, kernel_regularizer=tf.keras.regularizers.l2(0.01))\n",
    "layers.Dense(32, kernel_regularizer=tf.keras.regularizers.l1(0.01))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3 Training and Evaluation\n",
    "### 3.1 Setting Up Training\n",
    "Once the model is built, configure its learning process by calling the compile method:\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = tf.keras.Sequential()\n",
    "model.add(layers.Dense(32, activation='relu'))\n",
    "model.add(layers.Dense(32, activation='relu'))\n",
    "model.add(layers.Dense(10, activation='softmax'))\n",
    "model.compile(optimizer=tf.keras.optimizers.Adam(0.001),\n",
    "             loss=tf.keras.losses.categorical_crossentropy,\n",
    "             metrics=[tf.keras.metrics.categorical_accuracy])"
   ]
  },
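  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "compile accepts other combinations of optimizer, loss, and metrics in the same way. As a minimal sketch, assuming a hypothetical regression model named reg_model (not part of the original tutorial), the loss and metric can also be given as string names:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "from tensorflow.keras import layers\n",
    "\n",
    "# Hypothetical regression model: a single linear output unit\n",
    "reg_model = tf.keras.Sequential()\n",
    "reg_model.add(layers.Dense(32, activation='relu'))\n",
    "reg_model.add(layers.Dense(1))\n",
    "# Mean squared error loss, tracked together with mean absolute error\n",
    "reg_model.compile(optimizer=tf.keras.optimizers.Adam(0.01),\n",
    "                  loss='mse',\n",
    "                  metrics=['mae'])"
   ]
  },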
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3.2 Feeding NumPy Data\n",
    "For small datasets, the input data can be built with NumPy arrays."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train on 1000 samples, validate on 200 samples\n",
      "Epoch 1/10\n",
      "1000/1000 [==============================] - 1s 503us/sample - loss: 11.8979 - categorical_accuracy: 0.0920 - val_loss: 12.1516 - val_categorical_accuracy: 0.1550\n",
      "Epoch 2/10\n",
      "1000/1000 [==============================] - 0s 21us/sample - loss: 12.2874 - categorical_accuracy: 0.0910 - val_loss: 12.9158 - val_categorical_accuracy: 0.1150\n",
      "Epoch 3/10\n",
      "1000/1000 [==============================] - 0s 31us/sample - loss: 13.3758 - categorical_accuracy: 0.0940 - val_loss: 14.4959 - val_categorical_accuracy: 0.0900\n",
      "Epoch 4/10\n",
      "1000/1000 [==============================] - 0s 28us/sample - loss: 15.4028 - categorical_accuracy: 0.0920 - val_loss: 17.2651 - val_categorical_accuracy: 0.1000\n",
      "Epoch 5/10\n",
      "1000/1000 [==============================] - 0s 24us/sample - loss: 18.6433 - categorical_accuracy: 0.0930 - val_loss: 21.0552 - val_categorical_accuracy: 0.1000\n",
      "Epoch 6/10\n",
      "1000/1000 [==============================] - 0s 31us/sample - loss: 22.0638 - categorical_accuracy: 0.0930 - val_loss: 23.7690 - val_categorical_accuracy: 0.1000\n",
      "Epoch 7/10\n",
      "1000/1000 [==============================] - 0s 25us/sample - loss: 24.8034 - categorical_accuracy: 0.0930 - val_loss: 27.7582 - val_categorical_accuracy: 0.1000\n",
      "Epoch 8/10\n",
      "1000/1000 [==============================] - 0s 24us/sample - loss: 30.0158 - categorical_accuracy: 0.0920 - val_loss: 34.5013 - val_categorical_accuracy: 0.1000\n",
      "Epoch 9/10\n",
      "1000/1000 [==============================] - 0s 27us/sample - loss: 37.2771 - categorical_accuracy: 0.0920 - val_loss: 42.4729 - val_categorical_accuracy: 0.1000\n",
      "Epoch 10/10\n",
      "1000/1000 [==============================] - 0s 30us/sample - loss: 45.5502 - categorical_accuracy: 0.0930 - val_loss: 51.9351 - val_categorical_accuracy: 0.1000\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<tensorflow.python.keras.callbacks.History at 0x7f0d702641d0>"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Random data, used only to demonstrate the API\n",
    "train_x = np.random.random((1000, 72))\n",
    "train_y = np.random.random((1000, 10))\n",
    "\n",
    "val_x = np.random.random((200, 72))\n",
    "val_y = np.random.random((200, 10))\n",
    "\n",
    "model.fit(train_x, train_y, epochs=10, batch_size=100,\n",
    "          validation_data=(val_x, val_y))"
   ]
  },
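  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The targets in the cell above are uniform random numbers rather than class labels, so the reported loss and accuracy are not informative. As a minimal sketch, assuming a hypothetical array of integer class ids named labels (not part of the original tutorial), one-hot targets suitable for categorical_crossentropy could be built with NumPy and passed to fit in place of train_y:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "num_classes = 10\n",
    "# Hypothetical integer class ids, one per training sample\n",
    "labels = np.random.randint(0, num_classes, size=1000)\n",
    "# Row i of the identity matrix is the one-hot vector for class i\n",
    "train_y_onehot = np.eye(num_classes)[labels]\n",
    "print(train_y_onehot.shape)  # (1000, 10)"
   ]
  },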
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3.3 Feeding tf.data Datasets\n",
    "For large datasets, the training input can be built with tf.data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train for 30 steps, validate for 3 steps\n",
      "Epoch 1/10\n",
      "30/30 [==============================] - 0s 12ms/step - loss: 67.4564 - categorical_accuracy: 0.0948 - val_loss: 87.7801 - val_categorical_accuracy: 0.0938\n",
      "Epoch 2/10\n",
      "30/30 [==============================] - 0s 2ms/step - loss: 110.4207 - categorical_accuracy: 0.0983 - val_loss: 137.2176 - val_categorical_accuracy: 0.0729\n",
      "Epoch 3/10\n",
      "30/30 [==============================] - 0s 2ms/step - loss: 166.3288 - categorical_accuracy: 0.1026 - val_loss: 200.0635 - val_categorical_accuracy: 0.0729\n",
      "Epoch 4/10\n",
      "30/30 [==============================] - 0s 2ms/step - loss: 234.6779 - categorical_accuracy: 0.0929 - val_loss: 276.1790 - val_categorical_accuracy: 0.0938\n",
      "Epoch 5/10\n",
      "30/30 [==============================] - 0s 1ms/step - loss: 316.2306 - categorical_accuracy: 0.0855 - val_loss: 362.6940 - val_categorical_accuracy: 0.0938\n",
      "Epoch 6/10\n",
      "30/30 [==============================] - 0s 2ms/step - loss: 405.1462 - categorical_accuracy: 0.0962 - val_loss: 454.0105 - val_categorical_accuracy: 0.0938\n",
      "Epoch 7/10\n",
      "30/30 [==============================] - 0s 3ms/step - loss: 495.8588 - categorical_accuracy: 0.0897 - val_loss: 542.2636 - val_categorical_accuracy: 0.1042\n",
      "Epoch 8/10\n",
      "30/30 [==============================] - 0s 2ms/step - loss: 589.8635 - categorical_accuracy: 0.1132 - val_loss: 637.4122 - val_categorical_accuracy: 0.0833\n",
      "Epoch 9/10\n",
      "30/30 [==============================] - 0s 2ms/step - loss: 679.3736 - categorical_accuracy: 0.1079 - val_loss: 724.9229 - val_categorical_accuracy: 0.1146\n",
      "Epoch 10/10\n",
      "30/30 [==============================] - 0s 1ms/step - loss: 757.8416 - categorical_accuracy: 0.1004 - val_loss: 787.8435 - val_categorical_accuracy: 0.0938\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<tensorflow.python.keras.callbacks.History at 0x7f0d6841fc88>"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dataset = tf.data.Dataset.from_tensor_slices((train_x, train_y))\n",
    "dataset = dataset.batch(32)\n",
    "dataset = dataset.repeat()\n",
    "val_dataset = tf.data.Dataset.from_tensor_slices((val_x, val_y))\n",
    "val_dataset = val_dataset.batch(32)\n",
    "val_dataset = val_dataset.repeat()\n",
    "\n",
    "model.fit(dataset, epochs=10, steps_per_epoch=30,\n",
    "          validation_data=val_dataset, validation_steps=3)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3.4 Evaluation and Prediction\n",
    "The tf.keras.Model.evaluate and tf.keras.Model.predict methods perform evaluation and prediction. Both accept NumPy arrays as well as tf.data.Dataset inputs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r",
      "1000/1 [==============================
================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================
================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================
===============] - 0s 26us/sample - loss: 820.5671 - categorical_accuracy: 0.1000\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "30/30 [==============================] - 0s 1ms/step - loss: 786.1435 - categorical_accuracy: 0.0969\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[786.1434997558594, 0.096875]"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 模型评估\n",
    "test_x = np.random.random((1000, 72))\n",
    "test_y = np.random.random((1000, 10))\n",
    "model.evaluate(test_x, test_y, batch_size=32)\n",
    "test_data = tf.data.Dataset.from_tensor_slices((test_x, test_y))\n",
    "test_data = test_data.batch(32).repeat()\n",
    "model.evaluate(test_data, steps=30)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[0.245288   0.00419576 0.         ... 0.11711721 0.         0.        ]\n",
      " [0.29617038 0.00480588 0.         ... 0.14704514 0.         0.        ]\n",
      " [0.15192655 0.00397816 0.         ... 0.1538676  0.         0.        ]\n",
      " ...\n",
      " [0.23659316 0.00576886 0.         ... 0.13454369 0.         0.        ]\n",
      " [0.31788164 0.0062953  0.         ... 0.14174958 0.         0.        ]\n",
      " [0.40483308 0.01813794 0.         ... 0.15039128 0.         0.        ]]\n"
     ]
    }
   ],
   "source": [
    "# 模型预测\n",
    "result = model.predict(test_x, batch_size=32)\n",
    "print(result)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4 构建复杂模型\n",
    "\n",
    "### 4.1 函数式API\n",
    "tf.keras.Sequential 模型是层的简单堆叠，无法表示任意模型。使用 Keras的函数式API可以构建复杂的模型拓扑，例如：\n",
    "\n",
    "- 多输入模型，\n",
    "\n",
    "- 多输出模型，\n",
    "\n",
    "- 具有共享层的模型（同一层被调用多次），\n",
    "\n",
    "- 具有非序列数据流的模型（例如，残差连接）。\n",
    "\n",
    "**使用函数式 API 构建的模型具有以下特征：**\n",
    "\n",
    "- 层实例可调用并返回张量。\n",
    "- 输入张量和输出张量用于定义 tf.keras.Model 实例。\n",
    "- 此模型的训练方式和 Sequential 模型一样。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train on 1000 samples\n",
      "Epoch 1/5\n",
      "1000/1000 [==============================] - 0s 351us/sample - loss: 13.1064 - accuracy: 0.1080\n",
      "Epoch 2/5\n",
      "1000/1000 [==============================] - 0s 59us/sample - loss: 21.9265 - accuracy: 0.1180\n",
      "Epoch 3/5\n",
      "1000/1000 [==============================] - 0s 68us/sample - loss: 33.9123 - accuracy: 0.1210\n",
      "Epoch 4/5\n",
      "1000/1000 [==============================] - 0s 52us/sample - loss: 52.5335 - accuracy: 0.1190\n",
      "Epoch 5/5\n",
      "1000/1000 [==============================] - 0s 40us/sample - loss: 85.4629 - accuracy: 0.1180\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<tensorflow.python.keras.callbacks.History at 0x7f0d40696fd0>"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "input_x = tf.keras.Input(shape=(72,))\n",
    "hidden1 = layers.Dense(32, activation='relu')(input_x)\n",
    "hidden2 = layers.Dense(16, activation='relu')(hidden1)\n",
    "pred = layers.Dense(10, activation='softmax')(hidden2)\n",
    "# 构建tf.keras.Model实例\n",
    "model = tf.keras.Model(inputs=input_x, outputs=pred)\n",
    "model.compile(optimizer=tf.keras.optimizers.Adam(0.001),\n",
    "             loss=tf.keras.losses.categorical_crossentropy,\n",
    "             metrics=['accuracy'])\n",
    "model.fit(train_x, train_y, batch_size=32, epochs=5)"
   ]
  },
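  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Section 4.1 lists multi-input models and residual connections among the topologies that Sequential cannot express. The next cell is a minimal sketch of both and is not part of the original tutorial: the input names `main_in` and `aux_in`, the layer sizes, and the random data are illustrative assumptions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "from tensorflow.keras import layers\n",
    "\n",
    "# Two inputs (names and shapes are hypothetical)\n",
    "main_in = tf.keras.Input(shape=(72,), name='main_in')\n",
    "aux_in = tf.keras.Input(shape=(8,), name='aux_in')\n",
    "\n",
    "h = layers.Dense(32, activation='relu')(main_in)\n",
    "# Residual connection: add the block's input back to its output\n",
    "r = layers.Dense(32, activation='relu')(h)\n",
    "h = layers.add([h, r])\n",
    "# Merge the second input into the main branch\n",
    "merged = layers.concatenate([h, aux_in])\n",
    "out = layers.Dense(10, activation='softmax')(merged)\n",
    "\n",
    "demo_model = tf.keras.Model(inputs=[main_in, aux_in], outputs=out)\n",
    "demo_model.compile(optimizer='adam', loss='categorical_crossentropy')\n",
    "# A multi-input model is trained by passing one array per input\n",
    "demo_model.fit([np.random.random((64, 72)), np.random.random((64, 8))],\n",
    "               np.random.random((64, 10)), batch_size=32, epochs=1, verbose=0)\n",
    "print(demo_model.output_shape)"
   ]
  },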
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4.2 模型子类化\n",
    "可以通过对 tf.keras.Model 进行子类化并定义自己的前向传播来构建完全可自定义的模型。\n",
    "- 在\\__init\\__ 方法中创建层并将它们设置为类实例的属性。\n",
    "- 在\\__call\\__方法中定义前向传播"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train on 1000 samples\n",
      "Epoch 1/5\n",
      "1000/1000 [==============================] - 1s 1ms/sample - loss: 15.5352 - accuracy: 0.1130\n",
      "Epoch 2/5\n",
      "1000/1000 [==============================] - 0s 84us/sample - loss: 23.1448 - accuracy: 0.1150\n",
      "Epoch 3/5\n",
      "1000/1000 [==============================] - 0s 75us/sample - loss: 31.5394 - accuracy: 0.1000\n",
      "Epoch 4/5\n",
      "1000/1000 [==============================] - 0s 70us/sample - loss: 39.0555 - accuracy: 0.1040\n",
      "Epoch 5/5\n",
      "1000/1000 [==============================] - 0s 77us/sample - loss: 45.8000 - accuracy: 0.1010\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<tensorflow.python.keras.callbacks.History at 0x7f0d400d8dd8>"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "class MyModel(tf.keras.Model):\n",
    "    def __init__(self, num_classes=10):\n",
    "        super(MyModel, self).__init__(name='my_model')\n",
    "        self.num_classes = num_classes\n",
    "        # 定义网络层\n",
    "        self.layer1 = layers.Dense(32, activation='relu')\n",
    "        self.layer2 = layers.Dense(num_classes, activation='softmax')\n",
    "    def call(self, inputs):\n",
    "        # 定义前向传播\n",
    "        h1 = self.layer1(inputs)\n",
    "        out = self.layer2(h1)\n",
    "        return out\n",
    "    \n",
    "    def compute_output_shape(self, input_shape):\n",
    "        # 计算输出shape\n",
    "        shape = tf.TensorShape(input_shape).as_list()\n",
    "        shape[-1] = self.num_classes\n",
    "        return tf.TensorShape(shape)\n",
    "# 实例化模型类，并训练\n",
    "model = MyModel(num_classes=10)\n",
    "model.compile(optimizer=tf.keras.optimizers.RMSprop(0.001),\n",
    "             loss=tf.keras.losses.categorical_crossentropy,\n",
    "             metrics=['accuracy'])\n",
    "\n",
    "model.fit(train_x, train_y, batch_size=16, epochs=5)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4.3 自定义层\n",
    "通过对 tf.keras.layers.Layer 进行子类化并实现以下方法来创建自定义层：\n",
    "\n",
    "- \\__init\\__: (可选)定义该层要使用的子层\n",
    "- build：创建层的权重。使用 add_weight 方法添加权重。\n",
    "\n",
    "- call：定义前向传播。\n",
    "\n",
    "- compute_output_shape：指定在给定输入形状的情况下如何计算层的输出形状。\n",
    "- 可选，可以通过实现 get_config 方法和 from_config 类方法序列化层。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train on 1000 samples\n",
      "Epoch 1/5\n",
      "1000/1000 [==============================] - 0s 258us/sample - loss: 11.5199 - accuracy: 0.0860\n",
      "Epoch 2/5\n",
      "1000/1000 [==============================] - 0s 67us/sample - loss: 11.5205 - accuracy: 0.0870\n",
      "Epoch 3/5\n",
      "1000/1000 [==============================] - 0s 69us/sample - loss: 11.5205 - accuracy: 0.0850\n",
      "Epoch 4/5\n",
      "1000/1000 [==============================] - 0s 72us/sample - loss: 11.5201 - accuracy: 0.0830\n",
      "Epoch 5/5\n",
      "1000/1000 [==============================] - 0s 67us/sample - loss: 11.5201 - accuracy: 0.0810\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<tensorflow.python.keras.callbacks.History at 0x7f0d242d8320>"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "class MyLayer(layers.Layer):\n",
    "    def __init__(self, output_dim, **kwargs):\n",
    "        self.output_dim = output_dim\n",
    "        super(MyLayer, self).__init__(**kwargs)\n",
    "    \n",
    "    def build(self, input_shape):\n",
    "        shape = tf.TensorShape((input_shape[1], self.output_dim))\n",
    "        self.kernel = self.add_weight(name='kernel1', shape=shape,\n",
    "                                   initializer='uniform', trainable=True)\n",
    "        super(MyLayer, self).build(input_shape)\n",
    "    \n",
    "    def call(self, inputs):\n",
    "        return tf.matmul(inputs, self.kernel)\n",
    "\n",
    "    def compute_output_shape(self, input_shape):\n",
    "        shape = tf.TensorShape(input_shape).as_list()\n",
    "        shape[-1] = self.output_dim\n",
    "        return tf.TensorShape(shape)\n",
    "\n",
    "    def get_config(self):\n",
    "        base_config = super(MyLayer, self).get_config()\n",
    "        base_config['output_dim'] = self.output_dim\n",
    "        return base_config\n",
    "\n",
    "    @classmethod\n",
    "    def from_config(cls, config):\n",
    "        return cls(**config)\n",
    "\n",
    "# 使用自定义网络层构建模型\n",
    "model = tf.keras.Sequential(\n",
    "[\n",
    "    MyLayer(10),\n",
    "    layers.Activation('softmax')\n",
    "])\n",
    "\n",
    "\n",
    "model.compile(optimizer=tf.keras.optimizers.RMSprop(0.001),\n",
    "             loss=tf.keras.losses.categorical_crossentropy,\n",
    "             metrics=['accuracy'])\n",
    "\n",
    "model.fit(train_x, train_y, batch_size=16, epochs=5)\n",
    "\n",
    "        "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4.3 回调\n",
    "回调是传递给模型以自定义和扩展其在训练期间的行为的对象。我们可以编写自己的自定义回调，或使用tf.keras.callbacks中的内置函数，常用内置回调函数如下：\n",
    "\n",
    "- tf.keras.callbacks.ModelCheckpoint：定期保存模型的检查点。\n",
    "- tf.keras.callbacks.LearningRateScheduler：动态更改学习率。\n",
    "- tf.keras.callbacks.EarlyStopping：验证性能停止提高时进行中断培训。\n",
    "- tf.keras.callbacks.TensorBoard：使用TensorBoard监视模型的行为 。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train on 1000 samples, validate on 200 samples\n",
      "Epoch 1/5\n",
      "1000/1000 [==============================] - 0s 142us/sample - loss: 11.5200 - accuracy: 0.0870 - val_loss: 11.7025 - val_accuracy: 0.0750\n",
      "Epoch 2/5\n",
      "1000/1000 [==============================] - 0s 88us/sample - loss: 11.5206 - accuracy: 0.0900 - val_loss: 11.7015 - val_accuracy: 0.0950\n",
      "Epoch 3/5\n",
      "1000/1000 [==============================] - 0s 88us/sample - loss: 11.5197 - accuracy: 0.0810 - val_loss: 11.7022 - val_accuracy: 0.1100\n",
      "Epoch 4/5\n",
      "1000/1000 [==============================] - 0s 91us/sample - loss: 11.5198 - accuracy: 0.0810 - val_loss: 11.7029 - val_accuracy: 0.1250\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<tensorflow.python.keras.callbacks.History at 0x7f0d086590f0>"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "callbacks = [\n",
    "    tf.keras.callbacks.EarlyStopping(patience=2, monitor='val_loss'),\n",
    "    tf.keras.callbacks.TensorBoard(log_dir='./logs')\n",
    "]\n",
    "model.fit(train_x, train_y, batch_size=16, epochs=5,\n",
    "         callbacks=callbacks, validation_data=(val_x, val_y))"
   ]
  },
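  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The list above also names LearningRateScheduler, which the preceding example does not use. The next cell is a hedged sketch of how it might be wired in: the `halve_every_two` schedule, the small `sched_model`, and the random data are hypothetical and chosen only for illustration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "from tensorflow.keras import layers\n",
    "\n",
    "def halve_every_two(epoch, lr):\n",
    "    # Hypothetical schedule: halve the learning rate at epochs 2, 4, 6, ...\n",
    "    return lr * 0.5 if epoch > 0 and epoch % 2 == 0 else lr\n",
    "\n",
    "sched_model = tf.keras.Sequential([\n",
    "    layers.Dense(10, activation='softmax', input_shape=(72,))\n",
    "])\n",
    "sched_model.compile(optimizer=tf.keras.optimizers.SGD(0.1),\n",
    "                    loss='categorical_crossentropy')\n",
    "# The callback calls the schedule at the start of every epoch\n",
    "sched_model.fit(np.random.random((64, 72)), np.random.random((64, 10)),\n",
    "                epochs=4, verbose=0,\n",
    "                callbacks=[tf.keras.callbacks.LearningRateScheduler(halve_every_two)])\n",
    "# Inspect the learning rate the optimizer ended up with\n",
    "print(float(tf.keras.backend.get_value(sched_model.optimizer.learning_rate)))"
   ]
  },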
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 5 模型保存与恢复\n",
    "### 5.1 权重保存"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = tf.keras.Sequential([\n",
    "layers.Dense(64, activation='relu', input_shape=(32,)),  # 需要有input_shape\n",
    "layers.Dense(10, activation='softmax')])\n",
    "\n",
    "model.compile(optimizer=tf.keras.optimizers.Adam(0.001),\n",
    "              loss='categorical_crossentropy',\n",
    "              metrics=['accuracy'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 权重保存与重载\n",
    "model.save_weights('./weights/model')\n",
    "model.load_weights('./weights/model')\n",
    "# 保存为h5格式\n",
    "model.save_weights('./model.h5', save_format='h5')\n",
    "model.load_weights('./model.h5')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 5.2 保存网络结构"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'backend': 'tensorflow',\n",
      " 'class_name': 'Sequential',\n",
      " 'config': {'layers': [{'class_name': 'Dense',\n",
      "                        'config': {'activation': 'relu',\n",
      "                                   'activity_regularizer': None,\n",
      "                                   'batch_input_shape': [None, 32],\n",
      "                                   'bias_constraint': None,\n",
      "                                   'bias_initializer': {'class_name': 'Zeros',\n",
      "                                                        'config': {}},\n",
      "                                   'bias_regularizer': None,\n",
      "                                   'dtype': 'float32',\n",
      "                                   'kernel_constraint': None,\n",
      "                                   'kernel_initializer': {'class_name': 'GlorotUniform',\n",
      "                                                          'config': {'seed': None}},\n",
      "                                   'kernel_regularizer': None,\n",
      "                                   'name': 'dense_23',\n",
      "                                   'trainable': True,\n",
      "                                   'units': 64,\n",
      "                                   'use_bias': True}},\n",
      "                       {'class_name': 'Dense',\n",
      "                        'config': {'activation': 'softmax',\n",
      "                                   'activity_regularizer': None,\n",
      "                                   'bias_constraint': None,\n",
      "                                   'bias_initializer': {'class_name': 'Zeros',\n",
      "                                                        'config': {}},\n",
      "                                   'bias_regularizer': None,\n",
      "                                   'dtype': 'float32',\n",
      "                                   'kernel_constraint': None,\n",
      "                                   'kernel_initializer': {'class_name': 'GlorotUniform',\n",
      "                                                          'config': {'seed': None}},\n",
      "                                   'kernel_regularizer': None,\n",
      "                                   'name': 'dense_24',\n",
      "                                   'trainable': True,\n",
      "                                   'units': 10,\n",
      "                                   'use_bias': True}}],\n",
      "            'name': 'sequential_6'},\n",
      " 'keras_version': '2.2.4-tf'}\n"
     ]
    }
   ],
   "source": [
    "# 序列化成json\n",
    "import json\n",
    "import pprint\n",
    "json_str = model.to_json()\n",
    "pprint.pprint(json.loads(json_str))\n",
    "# 从json中重建模型\n",
    "fresh_model = tf.keras.models.model_from_json(json_str)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "backend: tensorflow\n",
      "class_name: Sequential\n",
      "config:\n",
      "  layers:\n",
      "  - class_name: Dense\n",
      "    config:\n",
      "      activation: relu\n",
      "      activity_regularizer: null\n",
      "      batch_input_shape: !!python/tuple [null, 32]\n",
      "      bias_constraint: null\n",
      "      bias_initializer:\n",
      "        class_name: Zeros\n",
      "        config: {}\n",
      "      bias_regularizer: null\n",
      "      dtype: float32\n",
      "      kernel_constraint: null\n",
      "      kernel_initializer:\n",
      "        class_name: GlorotUniform\n",
      "        config: {seed: null}\n",
      "      kernel_regularizer: null\n",
      "      name: dense_23\n",
      "      trainable: true\n",
      "      units: 64\n",
      "      use_bias: true\n",
      "  - class_name: Dense\n",
      "    config:\n",
      "      activation: softmax\n",
      "      activity_regularizer: null\n",
      "      bias_constraint: null\n",
      "      bias_initializer:\n",
      "        class_name: Zeros\n",
      "        config: {}\n",
      "      bias_regularizer: null\n",
      "      dtype: float32\n",
      "      kernel_constraint: null\n",
      "      kernel_initializer:\n",
      "        class_name: GlorotUniform\n",
      "        config: {seed: null}\n",
      "      kernel_regularizer: null\n",
      "      name: dense_24\n",
      "      trainable: true\n",
      "      units: 10\n",
      "      use_bias: true\n",
      "  name: sequential_6\n",
      "keras_version: 2.2.4-tf\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# 保持为yaml格式  #需要提前安装pyyaml\n",
    "\n",
    "yaml_str = model.to_yaml()\n",
    "print(yaml_str)\n",
    "# 从yaml数据中重新构建模型\n",
    "fresh_model = tf.keras.models.model_from_yaml(yaml_str)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**注意：子类模型不可序列化，因为其体系结构由call方法主体中的Python代码定义。**"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 5.3 保存整个模型\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train on 1000 samples\n",
      "Epoch 1/5\n",
      "1000/1000 [==============================] - 0s 380us/sample - loss: 11.5446 - accuracy: 0.1030\n",
      "Epoch 2/5\n",
      "1000/1000 [==============================] - 0s 43us/sample - loss: 11.5470 - accuracy: 0.1090\n",
      "Epoch 3/5\n",
      "1000/1000 [==============================] - 0s 46us/sample - loss: 11.5488 - accuracy: 0.1090\n",
      "Epoch 4/5\n",
      "1000/1000 [==============================] - 0s 44us/sample - loss: 11.5616 - accuracy: 0.1090\n",
      "Epoch 5/5\n",
      "1000/1000 [==============================] - 0s 43us/sample - loss: 11.5849 - accuracy: 0.1090\n"
     ]
    }
   ],
   "source": [
    "model = tf.keras.Sequential([\n",
    "  layers.Dense(10, activation='softmax', input_shape=(72,)),\n",
    "  layers.Dense(10, activation='softmax')\n",
    "])\n",
    "model.compile(optimizer='rmsprop',\n",
    "              loss='categorical_crossentropy',\n",
    "              metrics=['accuracy'])\n",
    "model.fit(train_x, train_y, batch_size=32, epochs=5)\n",
    "# 保存整个模型\n",
    "model.save('all_model.h5')\n",
    "# 导入整个模型\n",
    "model = tf.keras.models.load_model('all_model.h5')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 6 将keras用于Estimator\n",
    "Estimator API 用于针对分布式环境训练模型。它适用于一些行业使用场景，例如用大型数据集进行分布式训练并导出模型以用于生产"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING: Logging before flag parsing goes to stderr.\n",
      "W1003 15:52:41.191763 139697242609472 estimator.py:1821] Using temporary folder as model directory: /tmp/tmpliq4yvi6\n"
     ]
    }
   ],
   "source": [
    "model = tf.keras.Sequential([layers.Dense(10,activation='softmax'),\n",
    "                          layers.Dense(10,activation='softmax')])\n",
    "\n",
    "model.compile(optimizer=tf.keras.optimizers.RMSprop(0.001),\n",
    "              loss='categorical_crossentropy',\n",
    "              metrics=['accuracy'])\n",
    "\n",
    "estimator = tf.keras.estimator.model_to_estimator(model)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 7 Eager execution"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Eager execution是一个动态执行的编程环境，它可以立即评估操作。Keras不需要此功能，但它受tf.keras程序支持和对检查程序和调试有用。\n",
    "\n",
    "所有的tf.keras模型构建API都与Eager execution兼容。尽管可以使用Sequential和函数API，但Eager execution有利于模型子类化和构建自定义层：其要求以代码形式编写前向传递的API（而不是通过组装现有层来创建模型的API）。"
   ]
  },
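  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a small illustration of the point above (a sketch, not from the original notebook), operations run immediately under eager execution, so intermediate values can be inspected like ordinary Python objects, and tf.GradientTape can record them for differentiation. The variable names and values are illustrative assumptions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "\n",
    "# Eager execution is enabled by default in TensorFlow 2\n",
    "print(tf.executing_eagerly())\n",
    "\n",
    "w = tf.Variable([[1.0], [2.0]])\n",
    "x = tf.constant([[3.0, 4.0]])\n",
    "with tf.GradientTape() as tape:\n",
    "    y = tf.matmul(x, w)           # evaluated immediately, no session needed\n",
    "    loss = tf.reduce_sum(y ** 2)\n",
    "# Inspect the intermediate value directly as a NumPy array\n",
    "print(y.numpy())\n",
    "# Gradient of the loss with respect to the variable w\n",
    "print(tape.gradient(loss, w).numpy())"
   ]
  },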
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 8 多GPU上运行"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "tf.keras模型可使用tf.distribute.Strategy在多个GPU上运行 。该API在多个GPU上提供了分布式培训，几乎无需更改现有代码。\n",
    "\n",
    "当前tf.distribute.MirroredStrategy是唯一受支持的分发策略。MirroredStrategy在单台计算机上使用全缩减进行同步训练来进行图内复制。要使用 distribute.Strategys，请将优化器实例化以及模型构建和编译嵌套在Strategys中.scope()，然后训练模型。\n",
    "\n",
    "以下示例tf.keras.Model在单个计算机上的多GPU分配。\n",
    "\n",
    "首先，在分布式策略范围内定义一个模型："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "W1003 15:52:43.881619 139697242609472 cross_device_ops.py:1209] There is non-GPU devices in `tf.distribute.Strategy`, not using nccl allreduce.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"sequential_9\"\n",
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "dense_29 (Dense)             (None, 16)                176       \n",
      "_________________________________________________________________\n",
      "dense_30 (Dense)             (None, 1)                 17        \n",
      "=================================================================\n",
      "Total params: 193\n",
      "Trainable params: 193\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "strategy = tf.distribute.MirroredStrategy()\n",
    "with strategy.scope():\n",
    "    model = tf.keras.Sequential()\n",
    "    model.add(layers.Dense(16, activation='relu', input_shape=(10,)))\n",
    "    model.add(layers.Dense(1, activation='sigmoid'))\n",
    "    optimizer = tf.keras.optimizers.SGD(0.2)\n",
    "    model.compile(loss='binary_crossentropy', optimizer=optimizer)\n",
    "model.summary()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "然后像单gpu一样在数据上训练模型即可"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "32/32 [==============================] - 1s 42ms/step - loss: 0.7060\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<tensorflow.python.keras.callbacks.History at 0x7f0cde166860>"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x = np.random.random((1024, 10))\n",
    "y = np.random.randint(2, size=(1024, 1))\n",
    "x = tf.cast(x, tf.float32)\n",
    "dataset = tf.data.Dataset.from_tensor_slices((x, y))\n",
    "dataset = dataset.shuffle(buffer_size=1024).batch(32)\n",
    "model.fit(dataset, epochs=1)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
